{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1043b220",
   "metadata": {},
   "source": [
    "# Detect Outliers with Cleanlab and PyTorch Image Models (timm)\n",
    "\n",
    "This quick tutorial shows how to detect outliers (out-of-distribution examples) in image data, using the [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset as an example. You can easily replace the image dataset + neural network used here with any other Pytorch dataset + neural network (e.g. to instead detect outliers in text data with minimal code changes). \n",
    "\n",
    "**Overview of what we'll do in this tutorial:**\n",
    "\n",
    "Detect outliers using `feature_embeddings`\n",
    "\n",
    "- Pre-process [cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) into Pytorch datasets where `train_data` only contains images of animals and `test_data` contains images from all classes.\n",
    "\n",
    "- Use a pretrained neural network model from [timm](https://github.com/rwightman/pytorch-image-models) to extract feature embeddings of each image.\n",
    "\n",
    "- Use cleanlab to find naturally occurring outlier examples in the `train_data` (i.e. atypical images).\n",
    "\n",
    "- Find outlier examples in the `test_data` that do not stem from training data distribution (including out-of-distribution non-animal images).\n",
    "\n",
    "- Explore threshold selection for determining which images are outliers vs not.\n",
    "\n",
    "Detect outliers using `pred_probs` from a trained classifier\n",
    "\n",
    "- Adapt our [timm](https://github.com/rwightman/pytorch-image-models) network into a classifier by training an  additional output layer using the (in-distribution) training data.\n",
    "\n",
    "- Use cleanlab to find out-of-distribution examples in the dataset based on the probabilistic predictions of this classifier, as an alternative to relying on feature embeddings."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70016f64",
   "metadata": {},
   "source": [
    "<div class=\"alert alert-info\">\n",
    "Quickstart\n",
    "<br/>\n",
    "    \n",
    "Already have numeric **feature embeddings** for your data? Just run the code below to score how out-of-distribution  each example is.\n",
    "\n",
    "\n",
    "<div  class=markdown markdown=\"1\" style=\"background:white;margin:16px\">  \n",
    "    \n",
    "```python\n",
    "\n",
    "from cleanlab.outlier import OutOfDistribution\n",
    "    \n",
    "ood = OutOfDistribution()\n",
    "\n",
    "# To get outlier scores for train_data using feature matrix train_feature_embeddings\n",
    "ood_train_feature_scores = ood.fit_score(features=train_feature_embeddings)\n",
    "\n",
    "# To get outlier scores for additional test_data using feature matrix test_feature_embeddings\n",
    "ood_test_feature_scores = ood.score(features=test_feature_embeddings)\n",
    "    \n",
    "    \n",
    "```\n",
    "\n",
    "</div>\n",
    "    \n",
    "Already have `pred_probs` and `labels` for your classification dataset? Just run the code below to to score how out-of-distribution  each example is.\n",
    "\n",
    "\n",
    "<div  class=markdown markdown=\"1\" style=\"background:white;margin:16px\">  \n",
    "    \n",
    "```python\n",
    "\n",
    "from cleanlab.outlier import OutOfDistribution\n",
    "    \n",
    "ood = OutOfDistribution()\n",
    "\n",
    "# To get outlier scores for train_data using predicted class probabilities (from a trained classifier) and given class labels\n",
    "ood_train_predictions_scores = ood.fit_score(pred_probs=train_pred_probs, labels=labels)\n",
    "\n",
    "# To get outlier scores for additional test_data using predicted class probabilities\n",
    "ood_test_predictions_scores = ood.score(pred_probs=test_pred_probs)\n",
    "    \n",
    "    \n",
    "```\n",
    "    \n",
    "</div>\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45cb0f90",
   "metadata": {},
   "source": [
    "## 1. Install the required dependencies\n",
    "You can use `pip` to install all packages required for this tutorial as follows:\n",
    "\n",
    "```ipython3\n",
    "!pip install matplotlib torch torchvision timm\n",
    "!pip install cleanlab\n",
    "...\n",
    "# Make sure to install the version corresponding to this tutorial\n",
    "# E.g. if viewing master branch documentation:\n",
    "#     !pip install git+https://github.com/cleanlab/cleanlab.git\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bbebfc8",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Package installation (hidden on docs website).\n",
    "# If running on Colab, may want to use GPU (select: Runtime > Change runtime type > Hardware accelerator > GPU)\n",
    "# Package versions we used: matplotlib==3.5.1, torch==2.1.2, torchvision==2.1.2, timm==0.6.12\n",
    "\n",
    "dependencies = [\"matplotlib\", \"torch\", \"torchvision\", \"timm\", \"cleanlab\"]\n",
    "\n",
    "if \"google.colab\" in str(get_ipython()):  # Check if it's running in Google Colab\n",
    "    %pip install cleanlab  # for colab\n",
    "    cmd = ' '.join([dep for dep in dependencies if dep != \"cleanlab\"])\n",
    "    %pip install $cmd\n",
    "else:\n",
    "    dependencies_test = [dependency.split('>')[0] if '>' in dependency \n",
    "                         else dependency.split('<')[0] if '<' in dependency \n",
    "                         else dependency.split('=')[0] for dependency in dependencies]\n",
    "    missing_dependencies = []\n",
    "    for dependency in dependencies_test:\n",
    "        try:\n",
    "            __import__(dependency)\n",
    "        except ImportError:\n",
    "            missing_dependencies.append(dependency)\n",
    "\n",
    "    if len(missing_dependencies) > 0:\n",
    "        print(\"Missing required dependencies:\")\n",
    "        print(*missing_dependencies, sep=\", \")\n",
    "        print(\"\\nPlease install them before running the rest of this notebook.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41733949",
   "metadata": {},
   "source": [
    "Let's first import the required packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4396f544",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from pylab import rcParams\n",
    "import torch\n",
    "import torchvision\n",
    "import timm\n",
    "from sklearn import preprocessing\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.ensemble import BaggingClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from cleanlab.outlier import OutOfDistribution\n",
    "from cleanlab.rank import find_top_issues"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3792f82e",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# This (optional) cell is hidden from docs.cleanlab.ai \n",
    "# Set some seeds for reproducibility. \n",
    "\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False\n",
    "torch.cuda.manual_seed_all(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be38283d",
   "metadata": {},
   "source": [
    "## 2. Pre-process the Cifar10 dataset\n",
    "\n",
    "Each image in the original [cifar10 dataset](https://www.cs.toronto.edu/~kriz/cifar.html) belongs to 1 of 10 classes: `[airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck]`. \n",
    "After loading the data and processing the images, we manually remove some classes from the training dataset thereby making images from these classes outliers in the test dataset. Here we to remove all classes that are not an animal, such that test images from the following classes would be out-of-distribution: `[airplane, automobile, ship, truck]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd853a54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load cifar10 images into tensors for training (rescales pixel values to [0,1] interval):\n",
    "transform_normalize = torchvision.transforms.Compose(\n",
    "    [torchvision.transforms.ToTensor(),])\n",
    "\n",
    "train_data = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform_normalize)\n",
    "test_data = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform_normalize)\n",
    "\n",
    "# Define in (animal) vs out (non-animal) of distribution labels\n",
    "animal_classes = [2,3,4,5,6,7]  # labels correspond to animal images\n",
    "non_animal_classes = [0,1,8,9]  # labels that correspond to non-animal images\n",
    "\n",
    "# Remove non-animal images from the training dataset\n",
    "animal_idxs = np.where(np.isin(train_data.targets, animal_classes))[0]\n",
    "\n",
    "# Only work with small subset of each dataset to speedup tutorial\n",
    "train_idxs = np.random.choice(animal_idxs, len(animal_idxs) // 6, replace=False)\n",
    "test_idxs = np.random.choice(range(len(test_data)), len(test_data) // 10, replace=False)\n",
    "\n",
    "train_data  = torch.utils.data.Subset(train_data, train_idxs)  # select subset of animal images for train_data\n",
    "test_data  = torch.utils.data.Subset(test_data, test_idxs)  # select subset of all images for test_data\n",
    "print('train_data length: %s' % (len(train_data)))\n",
    "print('test_data length: %s' % (len(test_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1be5ff2e",
   "metadata": {},
   "source": [
    "#### Visualize some of the training and test examples"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47514fe7",
   "metadata": {},
   "source": [
    "<details><summary>See the implementation of `plot_images` and `visualize_outliers` **(click to expand)**</summary>\n",
    "\n",
    "```python\n",
    "# Note: This pulldown content is for docs.cleanlab.ai, if running on local Jupyter or Colab, please ignore it.\n",
    "\n",
    "txt_classes = {0: 'airplane', \n",
    "              1: 'automobile', \n",
    "              2: 'bird',\n",
    "              3: 'cat', \n",
    "              4: 'deer', \n",
    "              5: 'dog', \n",
    "              6: 'frog', \n",
    "              7: 'horse', \n",
    "              8:'ship', \n",
    "              9:'truck'}\n",
    "\n",
    "def imshow(img):\n",
    "    npimg = img.numpy()\n",
    "    return np.transpose(npimg, (1, 2, 0))\n",
    "\n",
    "def plot_images(dataset, show_labels=False):\n",
    "    plt.rcParams[\"figure.figsize\"] = (9,7)\n",
    "    for i in range(15):\n",
    "        X,y = dataset[i]\n",
    "        ax = plt.subplot(3,5,i+1)\n",
    "        if show_labels:\n",
    "            ax.set_title(txt_classes[int(y)])\n",
    "        ax.imshow(imshow(X))\n",
    "        ax.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "def visualize_outliers(idxs, data):\n",
    "    data_subset = torch.utils.data.Subset(data, idxs)\n",
    "    plot_images(data_subset)\n",
    "    \n",
    "```\n",
    "</details>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b64e0aa",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "txt_classes = {0: 'airplane', \n",
    "              1: 'automobile', \n",
    "              2: 'bird',\n",
    "              3: 'cat', \n",
    "              4: 'deer', \n",
    "              5: 'dog', \n",
    "              6: 'frog', \n",
    "              7: 'horse', \n",
    "              8:'ship', \n",
    "              9:'truck'}\n",
    "\n",
    "def imshow(img):\n",
    "    npimg = img.numpy()\n",
    "    return np.transpose(npimg, (1, 2, 0))\n",
    "\n",
    "def plot_images(dataset, show_labels=False):\n",
    "    plt.rcParams[\"figure.figsize\"] = (9,7)\n",
    "    for i in range(15):\n",
    "        X,y = dataset[i]\n",
    "        ax = plt.subplot(3,5,i+1)\n",
    "        if show_labels:\n",
    "            ax.set_title(txt_classes[int(y)])\n",
    "        ax.imshow(imshow(X))\n",
    "        ax.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "def visualize_outliers(idxs, data):\n",
    "    data_subset = torch.utils.data.Subset(data, idxs)\n",
    "    plot_images(data_subset)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eb28f354",
   "metadata": {},
   "source": [
    "Observe how there are only animals left in our `train_data`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a00aa3ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(train_data, show_labels=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df819e85",
   "metadata": {},
   "source": [
    "If we consider `train_data` to be representative of the typical data distribution, then non-animal images in `test_data` become outliers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41e5cb6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_images(test_data, show_labels=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "92caec8a",
   "metadata": {},
   "source": [
    "## 3. Use cleanlab and feature embeddings to find outliers in the data\n",
    "\n",
    "\n",
    "### Represent each image as a numeric feature embedding vector\n",
    "\n",
    "We can pass images through a neural network to generate vector embeddings via its hidden layer representation. Here we use a `resnet50` network from [timm](https://timm.fast.ai/), which has been pretrained on a large corpus of other images. Note that cleanlab's outlier detection can be applied to numeric feature embeddings generated from any model (or to the raw data features if they are already numeric vectors). Outlier detection works best with feature vectors whose values along each dimension are of a similar scale. "
   ]
  },
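  {
   "cell_type": "markdown",
   "id": "c7a1e4b2",
   "metadata": {},
   "source": [
    "As a rough illustration of the remark about feature scale, the sketch below standardizes each feature dimension using statistics computed from the training features only. The arrays are synthetic stand-ins rather than real image embeddings, and `standardize` is a hypothetical helper (not part of cleanlab or timm). Whether such scaling helps for your own embeddings is something to check empirically.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def standardize(train_features, other_features, eps=1e-12):\n",
    "    # Scale each dimension to zero mean / unit variance using train statistics only\n",
    "    mean = train_features.mean(axis=0)\n",
    "    std = train_features.std(axis=0)\n",
    "    std = np.where(std < eps, 1.0, std)  # avoid dividing by zero for constant dimensions\n",
    "    return (train_features - mean) / std, (other_features - mean) / std\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Synthetic features whose dimensions live on very different scales\n",
    "scales = np.array([1.0, 100.0, 0.01])\n",
    "train_raw = rng.normal(size=(200, 3)) * scales\n",
    "test_raw = rng.normal(size=(50, 3)) * scales\n",
    "\n",
    "train_scaled, test_scaled = standardize(train_raw, test_raw)\n",
    "print(train_scaled.std(axis=0))  # per-dimension spread after scaling\n",
    "```\n",
    "\n",
    "Reusing the training mean and standard deviation for the additional data keeps both sets in the same coordinate system, so distances between them remain comparable."
   ]
  },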
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cf25354",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generates 2048-dimensional feature embeddings from images\n",
    "def embed_images(model, dataloader):\n",
    "    feature_embeddings = []\n",
    "    for data in dataloader:\n",
    "        images, labels = data\n",
    "        with torch.no_grad():\n",
    "            embeddings = model(images)\n",
    "            feature_embeddings.extend(embeddings.numpy())\n",
    "    feature_embeddings = np.array(feature_embeddings)\n",
    "    return feature_embeddings  # each row corresponds to embedding of a different image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85a58d41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pretrained neural network\n",
    "model = timm.create_model('resnet50', pretrained=True, num_classes=0)  # this is a pytorch network\n",
    "model.eval()  # eval mode disables training-time operators (like batch normalization)\n",
    "\n",
    "# Use dataloaders to stream images through the network\n",
    "batch_size = 50\n",
    "trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False)\n",
    "testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# Generate feature embeddings\n",
    "train_feature_embeddings = embed_images(model, trainloader)\n",
    "print(f'Train embeddings pooled shape: {train_feature_embeddings.shape}')\n",
    "test_feature_embeddings = embed_images(model, testloader)\n",
    "print(f'Test embeddings pooled shape: {test_feature_embeddings.shape}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad857d69",
   "metadata": {},
   "source": [
    "### Scoring outliers in a given dataset (training data)\n",
    "\n",
    "Fitting cleanlab's ``OutOfDistribution`` class on ``feature_embeddings`` will find any naturally occurring outliers in a given dataset. These examples are atypical images that look strange or different from the majority of examples in the dataset. In our case, these correspond to odd-looking images of animals that do not resemble typical animals depicted in **cifar10**. This method produces a score in [0,1] for each example, where lower values correspond to more atypical examples (more likely out-of-distribution)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feb0f519",
   "metadata": {},
   "outputs": [],
   "source": [
    "ood = OutOfDistribution()\n",
    "train_ood_features_scores = ood.fit_score(features=train_feature_embeddings)\n",
    "\n",
    "top_train_ood_features_idxs = find_top_issues(quality_scores=train_ood_features_scores, top=15)\n",
    "visualize_outliers(top_train_ood_features_idxs, train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "756333f7",
   "metadata": {},
   "source": [
    "For fun, let's see what cleanlab considers the least likely outliers in the dataset! We can do this by calling `find_top_issues` on the negated outlier scores. These examples look quite  homogeneous as each one is similar to many other training images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089d5860",
   "metadata": {},
   "outputs": [],
   "source": [
    "bottom_train_ood_features_idxs = find_top_issues(quality_scores=-train_ood_features_scores, top=15)\n",
    "visualize_outliers(bottom_train_ood_features_idxs, train_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2521aefb",
   "metadata": {},
   "source": [
    "### Scoring outliers in additional test data\n",
    "\n",
    "Now suppose we want to find outlier images in some never before seen test data, in particular images unlikely to stem from the same distribution as the training data. We can use our already fitted `OutOfDistribution` estimator to score how typical each new test example would be under the training data distribution and visualize the most severe outliers in this additional data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78b1951c",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ood_features_scores = ood.score(features=test_feature_embeddings)\n",
    "\n",
    "top_ood_features_idxs = find_top_issues(test_ood_features_scores, top=15)\n",
    "visualize_outliers(top_ood_features_idxs, test_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c645c58",
   "metadata": {},
   "source": [
    "Many outliers identified in `test_data` depict (non-animal) classes not present in the training set. These non-animal images have very different feature embeddings than the animal-only images in the training data."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b5de6f6",
   "metadata": {},
   "source": [
    "### Deciding which test examples are outliers\n",
    "\n",
    "Given outlier scores, how do we determine how many of the top-ranked examples in ``test_data`` should be marked as outliers? \n",
    "\n",
    "Inevitably this has some true positive / false positive trade-off, so let's suppose we want to ensure around at most 5% false positives. We can use the 5th percentile of the distribution of `train_ood_features_scores` (assuming the training data are in-distribution examples without outliers) as a hard score threshold below which to consider a test example an outlier.\n",
    "\n",
    "Let's plot the 5th percentile of the training outlier score distribution (shown as red line)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e9dff81b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fifth_percentile = np.percentile(train_ood_features_scores, 5)  # 5th percentile of the train_data distribution\n",
    "\n",
    "# Plot outlier_score distributions and the 5th percentile cutoff\n",
    "fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))\n",
    "plt_range = [min(train_ood_features_scores.min(),test_ood_features_scores.min()), \\\n",
    "             max(train_ood_features_scores.max(),test_ood_features_scores.max())]\n",
    "axes[0].hist(train_ood_features_scores, range=plt_range, bins=50)\n",
    "axes[0].set(title='train_outlier_scores distribution', ylabel='Frequency')\n",
    "axes[0].axvline(x=fifth_percentile, color='red', linewidth=2)\n",
    "axes[1].hist(test_ood_features_scores, range=plt_range, bins=50)\n",
    "axes[1].set(title='test_outlier_scores distribution', ylabel='Frequency')\n",
    "axes[1].axvline(x=fifth_percentile, color='red', linewidth=2)\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74c39ab1",
   "metadata": {},
   "source": [
    "All test examples whose `test_ood_features_scores` fall left of the red line will be marked as an outlier.\n",
    "\n",
    "Let's plot the least-certain outliers of our `test_data` (i.e. 15 images with outlier scores right along the threshold). These are the images immediately to the left of that cutoff threshold (red line). The majority of them are still truly out-of-distribution non-animal images, but there are a few atypical-looking animals that are now erroneously identified as outliers as well."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "616769f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "sorted_idxs = test_ood_features_scores.argsort()\n",
    "ood_features_scores = test_ood_features_scores[sorted_idxs]\n",
    "ood_features_indices = sorted_idxs[ood_features_scores < fifth_percentile]  # Images in test data flagged as outliers\n",
    "\n",
    "visualize_outliers(ood_features_indices[::-1], test_data)"
   ]
  },
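  {
   "cell_type": "markdown",
   "id": "d3f08a6e",
   "metadata": {},
   "source": [
    "The reasoning behind the percentile threshold can be checked on made-up scores. In the sketch below (an illustration with synthetic numbers, not results from this tutorial), in-distribution scores are drawn from one distribution and out-of-distribution scores from another; the Beta distributions and sample sizes are arbitrary assumptions. Thresholding at the 5th percentile of the training scores should flag roughly 5% of fresh in-distribution examples, while flagging a much larger share of the out-of-distribution ones.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Made-up scores: in-distribution examples score high, out-of-distribution ones lower\n",
    "train_scores = rng.beta(8, 2, size=5000)\n",
    "test_in_scores = rng.beta(8, 2, size=2000)\n",
    "test_out_scores = rng.beta(2, 5, size=500)\n",
    "\n",
    "threshold = np.percentile(train_scores, 5)\n",
    "\n",
    "false_positive_rate = np.mean(test_in_scores < threshold)  # in-distribution flagged by mistake\n",
    "true_positive_rate = np.mean(test_out_scores < threshold)  # out-of-distribution correctly flagged\n",
    "print(f'threshold={threshold:.3f}, FPR={false_positive_rate:.3f}, TPR={true_positive_rate:.3f}')\n",
    "```\n",
    "\n",
    "Picking a lower percentile reduces false positives at the cost of missing more true outliers, and vice versa."
   ]
  },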
  {
   "cell_type": "markdown",
   "id": "cb4c0a06",
   "metadata": {},
   "source": [
    "### How does cleanlab detect outliers from feature values?\n",
    "\n",
    "Outlier scores are defined relative to the average distance (computed over feature values) between each example and its K nearest neighbors in the training data. Such scores have been found to be particularly effective for out-of-distribution detection, see this paper for more details:\n",
    "\n",
    "[Back to the Basics: Revisiting Out-of-Distribution Detection Baselines](https://arxiv.org/abs/2207.03061)\n",
    "\n",
    "\n",
    "Internally, cleanlab uses the `sklearn.neighbors.NearestNeighbor` class (with *cosine* distance) to find the K nearest neighbors, but you can easily use [another KNN estimator](https://github.com/cleanlab/examples/blob/master/outlier_detection_cifar10/outlier_detection_cifar10.ipynb) with cleanlab's `OutOfDistribution` class."
   ]
  },
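  {
   "cell_type": "markdown",
   "id": "e94b21c7",
   "metadata": {},
   "source": [
    "The idea of a KNN-distance outlier score can be sketched with plain NumPy. This is a simplified illustration under our own assumptions, not cleanlab's implementation: the helper names, the choice `k=5`, and the `exp(-distance)` mapping (so that lower scores mean more atypical) are illustrative choices, and the data are synthetic.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cosine_distances(A, B):\n",
    "    # Pairwise cosine distance (1 - cosine similarity) between rows of A and rows of B\n",
    "    A_unit = A / np.linalg.norm(A, axis=1, keepdims=True)\n",
    "    B_unit = B / np.linalg.norm(B, axis=1, keepdims=True)\n",
    "    return 1.0 - A_unit @ B_unit.T\n",
    "\n",
    "def knn_outlier_scores(train_features, query_features, k=5):\n",
    "    # Average distance from each query row to its k nearest training rows,\n",
    "    # mapped so that lower scores correspond to more atypical examples\n",
    "    dists = cosine_distances(query_features, train_features)\n",
    "    nearest = np.sort(dists, axis=1)[:, :k]\n",
    "    avg_dist = nearest.mean(axis=1)\n",
    "    return np.exp(-avg_dist)\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "# Training points cluster around one direction; the 'odd' queries point elsewhere\n",
    "center = np.array([1.0, 0.0, 0.0])\n",
    "train = center + 0.05 * rng.normal(size=(100, 3))\n",
    "typical = center + 0.05 * rng.normal(size=(5, 3))\n",
    "odd = np.array([0.0, 1.0, 0.0]) + 0.05 * rng.normal(size=(5, 3))\n",
    "\n",
    "typical_scores = knn_outlier_scores(train, typical)\n",
    "odd_scores = knn_outlier_scores(train, odd)\n",
    "print(typical_scores.round(3), odd_scores.round(3))\n",
    "```\n",
    "\n",
    "Queries that resemble the training cluster have small neighbor distances and hence scores near 1, while the queries pointing in a different direction receive clearly lower scores. When scoring the training set against itself with a sketch like this, each point should be excluded from its own neighbor list."
   ]
  },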
  {
   "cell_type": "markdown",
   "id": "937c7e97",
   "metadata": {},
   "source": [
    "## 4. Use cleanlab and `pred_probs` to find outliers in the data\n",
    "\n",
    "We sometimes wish to find outliers in classification datasets for which we do not have meaningful numeric feature representations. In this case, cleanlab can detect unusual examples in the data solely using predicted probabilities from a trained classifier.\n",
    "\n",
    "To get `pred_probs` here, a Logistic Regression classifier is fit on the already generated `train_feature_embeddings` (from our pretrained timm network) and the given label for each training image. We use a simple classifier here to quickly generate `pred_probs`, but in practice [fine-tuning the entire neural network for classification](https://github.com/cleanlab/examples/blob/master/outlier_detection_cifar10/outlier_detection_cifar10.ipynb) will be more effective (our approach here is equivalent to only training an extra output layer appended  on top of the pretrained network)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40fed4ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess data\n",
    "train_labels = np.array(train_data.dataset.targets)[train_data.indices]\n",
    "train_labels = np.unique(train_labels, return_inverse=True)[1]  # MAKE SURE to zero index training labels for sklearn\n",
    "test_labels = np.array(test_data.dataset.targets)[test_data.indices]\n",
    "\n",
    "scaler = preprocessing.StandardScaler().fit(train_feature_embeddings)\n",
    "train_feature_embeddings_scaled = scaler.transform(train_feature_embeddings)\n",
    "test_feature_embeddings_scaled = scaler.transform(test_feature_embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89f9db72",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Our classifier employs bagging to better account for epistemic uncertainty \n",
    "model = BaggingClassifier(LogisticRegression(max_iter=500), random_state=1, n_jobs=-1)\n",
    "model.fit(train_feature_embeddings_scaled, train_labels)\n",
    "\n",
    "train_pred_probs = model.predict_proba(train_feature_embeddings_scaled)\n",
    "train_pred_labels = train_pred_probs.argmax(1)\n",
    "accuracy = np.mean(train_pred_labels == train_labels)\n",
    "print(f\"Model accuracy on held-out train_data {accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03e3f7b7",
   "metadata": {},
   "source": [
    "We can use these `pred_probs` to again compute out-of-distribution scores for each image in our dataset using cleanlab's `OutOfDistribution` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "874c885a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ood = OutOfDistribution()\n",
    "train_ood_predictions_scores = ood.fit_score(pred_probs=train_pred_probs, labels=train_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dcff8e5a",
   "metadata": {},
   "source": [
    "We can repeat this for additional test data, to identify test images that do not stem from the training data distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e110fc4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_pred_probs = model.predict_proba(test_feature_embeddings_scaled)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85b60cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_ood_predictions_scores = ood.score(pred_probs=test_pred_probs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "702aa162",
   "metadata": {},
   "source": [
    "Detecting outliers based on feature embeddings can be done for arbitrary unlabeled datasets, but requires a meaningful numerical representation of the data.  Detecting outliers based on predicted probabilities applies mainly for labeled classification datasets, but can be done with any effective classifier. The effectiveness of the latter approach depends on: how much auxiliary information captured in the feature values is lost in the predicted probabilities (determined by the particular set of labels in the classification task), the accuracy of our classifier, and how properly its predictions reflect epistemic uncertainty. Read more about it [here](https://pub.towardsai.net/a-simple-adjustment-improves-out-of-distribution-detection-for-any-classifier-5e96bbb2d627)."
   ]
  },
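  {
   "cell_type": "markdown",
   "id": "f5d6037a",
   "metadata": {},
   "source": [
    "To give some intuition for how predicted probabilities can signal outliers, here is a minimal sketch that scores each example by one minus the normalized entropy of its predicted class distribution. This is an illustration under our own assumptions rather than cleanlab's exact computation (which may additionally adjust `pred_probs` using the given labels); `entropy_ood_scores` is a hypothetical helper and the probabilities are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def entropy_ood_scores(pred_probs, eps=1e-12):\n",
    "    # Score in [0, 1]: lower means the classifier is less certain (more likely an outlier)\n",
    "    p = np.clip(pred_probs, eps, 1.0)\n",
    "    entropy = -(p * np.log(p)).sum(axis=1)\n",
    "    max_entropy = np.log(pred_probs.shape[1])  # entropy of the uniform distribution\n",
    "    return 1.0 - entropy / max_entropy\n",
    "\n",
    "pred_probs = np.array([\n",
    "    [0.98, 0.01, 0.01],  # confident prediction\n",
    "    [0.60, 0.30, 0.10],  # moderately uncertain\n",
    "    [0.34, 0.33, 0.33],  # nearly uniform: the classifier has little idea\n",
    "])\n",
    "scores = entropy_ood_scores(pred_probs)\n",
    "print(scores.round(3))\n",
    "```\n",
    "\n",
    "A score like this is only as informative as the classifier behind it: an overconfident model can assign a peaked distribution to an image unlike anything it was trained on, which is why the quality of the classifier's uncertainty estimates matters."
   ]
  },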
  {
   "cell_type": "markdown",
   "id": "03a5c870",
   "metadata": {},
   "source": [
    "## Spending too much time on data quality?\n",
    "\n",
    "Using this open-source package effectively can require significant ML expertise and experimentation, plus handling detected data issues can be cumbersome.\n",
    "\n",
    "That’s why we built [Cleanlab Studio](https://cleanlab.ai/blog/data-centric-ai/) -- an automated platform to find **and fix** issues in your dataset, 100x faster and more accurately.  Cleanlab Studio automatically runs optimized data quality algorithms from this package on top of cutting-edge AutoML & Foundation models fit to your data, and helps you fix detected issues via a smart data correction interface. [Try it](https://cleanlab.ai/) for free!\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img src=\"https://raw.githubusercontent.com/cleanlab/assets/master/cleanlab/ml-with-cleanlab-studio.png\" alt=\"The modern AI pipeline automated with Cleanlab Studio\">\n",
    "</p>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17f96fa6",
   "metadata": {
    "nbsphinx": "hidden"
   },
   "outputs": [],
   "source": [
    "# Note: This cell is only for docs.cleanlab.ai, if running on local Jupyter or Colab, please ignore it.\n",
    "\n",
    "# Verify the top identified test outliers data are mostly non-animal images\n",
    "top_ood_features_subset = torch.utils.data.Subset(test_data, top_ood_features_idxs)\n",
    "num_animals = len([i for i in range(len(top_ood_features_subset)) if top_ood_features_subset[i][1] in animal_classes])\n",
    "non_animal_frac = 1 - (num_animals / len(top_ood_features_subset))\n",
    "if non_animal_frac < 0.81:\n",
    "    raise Exception(f\"Not enough non-animal images amongst top-ranked outliers in test_data, only: {non_animal_frac}\")\n",
    "\n",
    "top_ood_predictions_idxs = (test_ood_predictions_scores).argsort()[:15]\n",
    "top_ood_predictions_subset = torch.utils.data.Subset(test_data, top_ood_predictions_idxs)\n",
    "num_animals = len([i for i in range(len(top_ood_predictions_subset)) if top_ood_predictions_subset[i][1] in animal_classes])\n",
    "non_animal_frac = 1 - (num_animals / len(top_ood_predictions_subset))\n",
    "if non_animal_frac < 0.50:\n",
    "    raise Exception(f\"Not enough non-animal images amongst top-ranked ood datapoints in test_data, only: {non_animal_frac}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
